import torch
from einops import repeat, rearrange
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.ldm.flux.model import Flux as FluxInnerModel
from comfy.model_base import Flux as FluxModel


def patched_flux_forward(
    self: FluxInnerModel,
    x,
    timestep,
    context,
    y,
    guidance=None,
    control=None,
    transformer_options={},  # NOTE: shared default is never mutated here; kept to mirror ComfyUI's Flux.forward signature
    **kwargs,
):
    """Forward pass for a PhotoDoddle-patched Flux inner model.

    Expects ``x`` to carry 32 latent channels: the 16 noisy latent channels
    stacked with the 16 condition-image channels produced by the
    PhotoDoddleConditioning node. Both halves are patchified into a single
    doubled token sequence (``n=2``), run through the original Flux
    transformer (``self.forward_orig``), and only the first 16 channels —
    the denoised latent half — are returned.

    Raises:
        Exception: if ``x`` does not have exactly 32 channels.
    """
    bs, c, h, w = x.shape
    if c != 32:
        raise Exception(
            f"Input latent channel count {c} is not 32. The patched PhotoDoddle Flux model requires conditions generated by the PhotoDoddleConditioning node."
        )
    patch_size = self.patch_size
    x = pad_to_patch_size(x, (patch_size, patch_size))

    # Patchify both stacked 16-channel halves (n=2) into one doubled token
    # sequence of length 2 * h_len * w_len.
    img = rearrange(
        x,
        "b (n c) (h ph) (w pw) -> b (n h w) (c ph pw)",
        n=2,
        ph=patch_size,
        pw=patch_size,
    )

    # Token grid size, rounded up to cover the padded spatial extent.
    h_len = (h + (patch_size // 2)) // patch_size
    w_len = (w + (patch_size // 2)) // patch_size
    # Positional ids: channel 1 is the row index, channel 2 the column index.
    # Both halves of the doubled sequence repeat the same ids so condition
    # tokens align spatially with their latent tokens.
    img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
    img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(
        0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype
    ).unsqueeze(1)
    img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(
        0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype
    ).unsqueeze(0)
    img_ids = repeat(img_ids, "h w c -> b (n h w) c", b=bs, n=2)

    txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
    out = self.forward_orig(
        img,
        img_ids,
        context,
        txt_ids,
        timestep,
        y,
        guidance,
        control,
        transformer_options,
        attn_mask=kwargs.get("attention_mask", None),
    )
    # Un-patchify with patch_size (previously hard-coded to 2, inconsistent
    # with the patchify above) and keep only the first c // 2 == 16 channels
    # — the denoised half — cropped back to the pre-padding spatial size.
    return rearrange(
        out,
        "b (n h w) (c ph pw) -> b (n c) (h ph) (w pw)",
        n=2,
        h=h_len,
        w=w_len,
        ph=patch_size,
        pw=patch_size,
    )[:, : c // 2, :h, :w]


def patched_concat_cond(self: "FluxModel", **kwargs):
    """Build the concat conditioning latent for a PhotoDoddle-patched model.

    Validates that the (possibly lora-modified) model still expects exactly
    ``out_channels`` input channels, then returns the VAE-encoded condition
    image (``concat_latent_image``) processed through the model's latent
    scaling, moved to the sampling device.

    Raises:
        Exception: if the input/output channel counts disagree (e.g. a Flux
            control lora widened ``img_in``), or if no ``concat_latent_image``
            was provided in the conditioning.
    """
    try:
        # Handle Flux control loras dynamically changing the img_in weight.
        num_channels = self.diffusion_model.img_in.weight.shape[1] // (
            self.diffusion_model.patch_size * self.diffusion_model.patch_size
        )
    except Exception:
        # Some cases like tensorrt might not have the weights accessible.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        num_channels = self.model_config.unet_config["in_channels"]
    out_channels = self.model_config.unet_config["out_channels"]
    if num_channels != out_channels:
        raise Exception(
            f"Model input channel count {num_channels} is not {out_channels}. Did you mistakenly load a Flux control lora?"
        )

    image = kwargs.get("concat_latent_image", None)
    if image is None:
        # Previously this fell through to `None.to(device)` and crashed with
        # an opaque AttributeError.
        raise Exception(
            "No concat_latent_image found in conditioning. Use the PhotoDoddleConditioning node to prepare the conditioning."
        )
    device = kwargs["device"]
    image = self.process_latent_in(image.to(device))
    return image


class PhotoDoddleConditioning:
    """ComfyUI node that patches a Flux model for PhotoDoddle inference.

    Monkey-patches the inner Flux forward and the outer model's concat_cond,
    encodes the reference image through the VAE, attaches the resulting
    latent to both conditionings, and emits a random starting latent of the
    same shape.
    """

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("model", "positive", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "duanyll/models"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "vae": ("VAE",),
            "pixels": ("IMAGE",),
        }
        return {"required": required}

    def encode(self, model, positive, negative, pixels, vae):
        config_name = str(type(model.model.model_config).__name__)
        if "Flux" not in config_name:
            raise Exception(
                f"Attempted to patch a {config_name} model. PhotoDoddle is only compatible with Flux models."
            )
        # Bind the patched functions as methods on the live model instances.
        inner_model = model.model.diffusion_model
        inner_model.forward = patched_flux_forward.__get__(inner_model)
        model.model.concat_cond = patched_concat_cond.__get__(model.model)

        height, width = pixels.shape[1], pixels.shape[2]
        if height % 16 or width % 16:
            raise Exception(
                f"Image dimensions must be divisible by 16. Got ({height}, {width})."
            )

        concat_latent = vae.encode(pixels)

        def with_concat(conditioning):
            # Copy each (embedding, options) pair, injecting the encoded image.
            return [
                [entry[0], {**entry[1], "concat_latent_image": concat_latent}]
                for entry in conditioning
            ]

        return (
            model,
            with_concat(positive),
            with_concat(negative),
            {"samples": torch.randn_like(concat_latent)},
        )